{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VERIFICATION TESTING\n",
    "\n",
    "# HER2 One Scanner - Aperio NIH\n",
    "\n",
    "- 5-Fold (80/20) split, No Holdout Set\n",
    "- Truth = Categorical from Mean of 7 continuous scores \n",
    "- Epoch at automatic Stop when loss<.001 change \n",
    "- LeNet model, 10 layers, Dropout (0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/usr/lib/python2.7/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "from keras.callbacks import EarlyStopping\n",
    "from PIL import Image\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Activation, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D, Lambda\n",
    "from keras.layers import Dense\n",
    "from keras.wrappers.scikit_learn import KerasClassifier\n",
    "from keras.utils import np_utils\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from sklearn import metrics\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.cross_validation import train_test_split\n",
    "from sklearn.metrics import roc_curve, auc, classification_report\n",
    "import csv\n",
    "import cv2\n",
    "import scipy\n",
    "import os\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#For single scanner\n",
    "BASE_PATH = '/home/diam/Desktop/1Scanner_VerificationTest_HER2data/Aperio_NIH/'\n",
    "#BASE PATH for working from home:\n",
    "#BASE_PATH = '/home/OSEL/Desktop/HER2_data_categorical/'\n",
    "\n",
    "batch_size = 32\n",
    "num_classes = 3\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Data - Practice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#This is the version from Ravi's code:\n",
    "\n",
    "#FDA\n",
    "#X_FDA = []\n",
    "#idx_FDA = []\n",
    "#for index, image_filename in list(enumerate(BASE_PATH)):\n",
    "#\timg_file = cv2.imread(BASE_PATH + '/' + image_filename)\n",
    "#\tif img_file is not None:\n",
    "\t\t#img_file = smisc.imresize(arr = img_file, size = (600,760,3))\n",
    "#\t\timg_file = smisc.imresize(arr = img_file, size = (120,160,3))\t\t\n",
    "#\t\timg_arr = np.asarray(img_file)\n",
    "#\t\tX_FDA.append(img_arr)\n",
    "#\t\tidx_FDA.append(index)\n",
    "\n",
    "#X_FDA = np.asarray(X_FDA)\n",
    "#idx_FDA = np.asarray(idx_FDA)\n",
    "#random.seed(rs)\n",
    "#random_id = random.sample(idx_FDA, len(idx_FDA)/2)\n",
    "#random_FDA = []\n",
    "#for i in random_id:\n",
    "#\trandom_FDA.append(X_FDA[i])\n",
    "\n",
    "#random_FDA = np.asarray(random_FDA)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Data - Real"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(folder):\n",
    "    X = []\n",
    "    y = []\n",
    "    filenames = []\n",
    "\n",
    "    for hclass in os.listdir(folder):\n",
    "        if not hclass.startswith('.'):\n",
    "            if hclass in [\"1\"]:\n",
    "                label = 1\n",
    "            else: #label must be 1 or 2\n",
    "                if hclass in [\"2\"]:\n",
    "                    label = 2\n",
    "                else:\n",
    "                    label = 3\n",
    "            for image_filename in os.listdir(folder + hclass):\n",
    "                filename = folder + hclass + '/' + image_filename\n",
    "                img_file = cv2.imread(folder + hclass + '/' + image_filename)\n",
    "                \n",
    "                if img_file is not None:\n",
    "                    img_file = scipy.misc.imresize(arr=img_file, size=(120, 160, 3))\n",
    "                    img_arr = np.asarray(img_file)\n",
    "                    X.append(img_arr)\n",
    "                    y.append(label)\n",
    "                    filenames.append(filename)\n",
    "\n",
    "    X = np.asarray(X)\n",
    "    y = np.asarray(y)\n",
    "    z = np.asarray(filenames)\n",
    "    return X,y,filenames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python2.7/dist-packages/ipykernel_launcher.py:20: DeprecationWarning: `imresize` is deprecated!\n",
      "`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
      "Use ``skimage.transform.resize`` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "241\n",
      "241\n",
      "[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
      " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
      " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
      " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n",
      " 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n",
      " 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "241\n",
      "[[ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]]\n"
     ]
    }
   ],
   "source": [
    "X, y, z = get_data(BASE_PATH)\n",
    "\n",
    "\n",
    "#print(X)\n",
    "#print(y)\n",
    "#print(z)\n",
    "print(len(X))\n",
    "print(len(y))\n",
    "print(y)\n",
    "print(len(z))\n",
    "\n",
    "#INTEGER ENCODE\n",
    "#https://machinelearningmastery.com/how-to-one-hot-encode-sequence-data-in-python/\n",
    "encoder = LabelEncoder()\n",
    "y_cat = np_utils.to_categorical(encoder.fit_transform(y))\n",
    "print(y_cat)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Old Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "#encoder = LabelEncoder()\n",
    "#encoder.fit(y)\n",
    "\n",
    "#X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)\n",
    "\n",
    "#encoded_y_train = encoder.transform(y_train)\n",
    "#encoded_y_test = encoder.transform(y_test)\n",
    "\n",
    "#y_train = np_utils.to_categorical(encoded_y_train)\n",
    "#y_test = np_utils.to_categorical(encoded_y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit Model with K-Fold X-Val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "KFold(n_splits=5, random_state=5, shuffle=True)\n"
     ]
    }
   ],
   "source": [
    "kf = KFold(n_splits = 5, random_state=5, shuffle=True)\n",
    "print(kf.get_n_splits(y))\n",
    "print(kf)\n",
    "\n",
    "\n",
    "#for train_index, test_index in kf.split(y):\n",
    "#    X_train, X_test = X[train_index], X[test_index]\n",
    "#    print(train_index, test_index)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fold #1\n",
      "Epoch 00081: early stopping\n",
      "Fold Score (accuracy): 0.69387755102\n",
      "[[ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]]\n",
      "fold #2\n",
      "Epoch 00060: early stopping\n",
      "Fold Score (accuracy): 0.8125\n",
      "[[ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]]\n",
      "fold #3\n",
      "Epoch 00074: early stopping\n",
      "Fold Score (accuracy): 0.833333333333\n",
      "[[ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]]\n",
      "fold #4\n",
      "Epoch 00033: early stopping\n",
      "Fold Score (accuracy): 0.708333333333\n",
      "[[ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]]\n",
      "fold #5\n",
      "Epoch 00071: early stopping\n",
      "Fold Score (accuracy): 0.8125\n",
      "[[ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  1.  0.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 0.  0.  1.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 1.  0.  0.]]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "oos_y = []\n",
    "oos_pred = []\n",
    "fold = 0\n",
    "\n",
    "\n",
    "for train, test in kf.split(y_cat):\n",
    "    fold+=1\n",
    "    print(\"fold #{}\".format(fold))\n",
    "        \n",
    "    X_train = X[train]\n",
    "    y_train = y_cat[train]\n",
    "    X_test = X[test]\n",
    "    y_test = y_cat[test]\n",
    "    \n",
    "    #encoder = LabelEncoder()\n",
    "    #encoder.fit(y_test)\n",
    "    #y_train = np_utils.to_categorical(encoder.transform(y_train))\n",
    "    #y_test = np_utils.to_categorical(encoder.transform(y_test))\n",
    "    \n",
    "    model = Sequential()\n",
    "    model.add(Lambda(lambda x: x * 1./255., input_shape=(120, 160, 3), output_shape=(120, 160, 3)))\n",
    "    model.add(Conv2D(32, (3, 3), input_shape=(120, 160, 3)))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Conv2D(32, (3, 3)))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Conv2D(64, (3, 3)))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Flatten())  # this converts our 3D feature maps to 1D feature vectors\n",
    "    model.add(Dense(64))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(Dropout(0.7))\n",
    "    model.add(Dense(3))\n",
    "    model.add(Activation('softmax'))\n",
    "\n",
    "    model.compile(loss='categorical_crossentropy',\n",
    "                optimizer='rmsprop',\n",
    "                metrics=['accuracy'])\n",
    "    \n",
    "    monitor = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=25, verbose=1, mode='auto')\n",
    "    \n",
    "    model.fit(\n",
    "            X_train,\n",
    "            y_train,\n",
    "            validation_data=(X_test,y_test),\n",
    "            callbacks=[monitor],\n",
    "            shuffle=True,\n",
    "            batch_size=batch_size,\n",
    "            verbose=0,\n",
    "            epochs=1000)\n",
    "        \n",
    "    pred = model.predict(X_test)\n",
    "        \n",
    "    oos_y.append(y_test)\n",
    "    pred = np.argmax(pred,axis=1)\n",
    "    oos_pred.append(pred)\n",
    "    \n",
    "        #measure the fold's accuracy\n",
    "    y_compare = np.argmax(y_test,axis=1) #for accuracy calculation\n",
    "    score = metrics.accuracy_score(y_compare, pred)\n",
    "    print(\"Fold Score (accuracy): {}\".format(score))\n",
    "    print(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15rc1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
